import numpy as np

# Fixed per-slot label weights. dataset(..., cff=True) multiplies the v_s
# per-slot retrieved values by coeffs[:v_s] before summing them into a label,
# so at most len(coeffs) == 25 search slots can be weighted.
coeffs = [-0.71347752, -0.6062312 ,  0.66062073, -0.61213017, -0.82631877,
          0.75908621, -0.36274866,  0.84383333, -0.37908235,  0.46539368,
          0.98675292,  0.20388434, -0.78718797, -0.74316581,  0.11539548,
          0.35559215,  0.32584524,  0.58590426,  0.95284985, -0.77999937,
          -0.660806  , -0.35542737,  0.93532822,  0.75049373, -0.27825577]

def search(x, y, minimize=True):
    """Per row, return the column of ``x`` nearest to (or farthest from) ``y``.

    Args:
        x: 2-D array, one row per point (``num_points x (length - 1)``).
        y: one reference value per point (reshaped to a column).
        minimize: if True pick the smallest ``|x - y|``, otherwise the largest.

    Returns:
        Integer array with one column index per row of ``x``.
    """
    reference = np.reshape(y, [-1, 1])  # column vector so it broadcasts across each row
    distance = np.abs(x - reference)
    reduce_fn = np.argmin if minimize else np.argmax
    return reduce_fn(distance, 1)

def search_by_indices(data, time, index, minimize=True):
    """For every point, find the step whose feature best matches step ``time``.

    Feature ``index`` at step ``time`` is compared with the same feature at
    every step of the same point. The step with the smallest
    (``minimize=True``) or largest (``minimize=False``) absolute difference is
    returned. The query step ``time`` itself is excluded. It can only be
    returned when it is the sole step.

    Args:
        data: array indexed as ``data[point, step, feature]``.
        time: index of the query step.
        index: feature column used for the comparison.
        minimize: pick the closest (True) or farthest (False) step.

    Returns:
        Integer array of shape ``(data.shape[0],)`` with the chosen step per point.
    """
    values = data[:, :, index]
    curr = values[:, time]
    # Work in float so an infinite sentinel can be stored even for integer data.
    deltas = np.abs(values - curr[:, None]).astype(float, copy=False)
    # Exclude the query step. The sentinel must lose under the chosen
    # criterion: +inf never wins argmin, -inf never wins argmax. A single +inf
    # mask made argmax always return ``time`` itself.
    if minimize:
        deltas[:, time] = np.inf
        return np.argmin(deltas, 1)
    deltas[:, time] = -np.inf
    return np.argmax(deltas, 1)

def retrieve(data, winners, index):
    """Gather ``data[p, winners[p], index]`` for every point ``p``.

    Args:
        data: array indexed as ``data[point, step, feature]``.
        winners: one step index per point.
        index: feature column to read.

    Returns:
        Array with one gathered value per point.
    """
    point_ids = np.arange(data.shape[0])
    return data[point_ids, winners, index]

def sum(a, b):
    """Return ``a + b``.

    Takes two arguments, so it fits the ``func`` slot of ``rule``.
    NOTE(review): this module-level name shadows the builtin ``sum`` everywhere
    in this module; it is kept for backward compatibility with callers.
    """
    total = a + b
    return total

def subtract(a, b):
    """Return ``a - b``; a two-argument combiner usable as ``rule``'s ``func``."""
    difference = a - b
    return difference

def max(a, b):
    """Return the element-wise maximum of ``a`` and ``b`` (``np.maximum``).

    NOTE(review): this module-level name shadows the builtin ``max`` everywhere
    in this module; it is kept for backward compatibility with callers.
    """
    larger = np.maximum(a, b)
    return larger

def rule(data, time, search1, search2, retrieve1, retrieve2, func):
    """Run two search/retrieve passes at step ``time`` and combine the results.

    Each pass finds, per point, the step closest to ``time`` on a search
    feature (``search1`` / ``search2``). It then reads a retrieve feature
    (``retrieve1`` / ``retrieve2``) at that step. The two retrieved arrays are
    combined with ``func``.

    Returns:
        Tuple ``(func(value_1, value_2), winners_1, winners_2)``, where the
        winners are the per-point step indices chosen by each search.
    """
    winners_1 = search_by_indices(data, time, search1)
    winners_2 = search_by_indices(data, time, search2)
    value_1 = retrieve(data, winners_1, retrieve1)
    value_2 = retrieve(data, winners_2, retrieve2)
    return func(value_1, value_2), winners_1, winners_2

def onehot(task, num_points, length, v_s, v_p):
    """One-hot encode integer ids in ``task`` into ``v_p`` classes.

    ``task`` must hold ``num_points * length * v_s`` ids. The encoding is
    reshaped to ``(num_points, length, v_s, v_p)``.
    """
    row_ids = np.arange(task.size)
    encoded = np.zeros((task.size, v_p))
    encoded[row_ids, task] = 1.
    return encoded.reshape((num_points, length, v_s, v_p))

def dataset(num_points, length, v_s, v_p, cff=False, rng=None):
    """Generate a synthetic search-and-retrieve dataset.

    Each point has ``length`` steps with ``v_s`` search-key features followed
    by ``v_p`` value features, all standard normal. Every (point, step, slot)
    also gets a random task id in ``[0, v_p)``.

    For each step ``i`` and search slot ``k``, the closest *other* step on key
    ``k`` is found (``search_by_indices``). Its ``v_p`` value features are
    gathered. The label for ``(point, i)`` sums, over slots, the value selected
    by that slot's task id. When ``cff`` is set, slot ``k`` is weighted by
    ``coeffs[k]``.

    Args:
        num_points: number of independent sequences.
        length: steps per sequence.
        v_s: number of search-key features (slots).
        v_p: number of value features / task classes.
        cff: weight each slot by the module-level ``coeffs``.
        rng: optional ``numpy.random.Generator``. When None, the global
            ``np.random`` state is used, with the same calls as before.

    Returns:
        ``(inp, labels, searches)``:
        - ``inp``: ``(num_points, length, v_p + v_s + v_s * v_p)``, the raw
          features followed by the flattened one-hot task ids.
        - ``labels``: ``(num_points, length)``.
        - ``searches``: ``(num_points, length, length)``, 1 where step ``i``
          selected step ``j`` in any slot.

    Raises:
        ValueError: if ``cff`` is set and ``v_s > len(coeffs)``.
    """
    # Keep the legacy global-RNG call order so np.random.seed(...) still reproduces.
    if rng is None:
        data = np.random.randn(num_points, length, v_p + v_s)
        task = np.random.choice(v_p, [num_points * length * v_s])
    else:
        data = rng.standard_normal((num_points, length, v_p + v_s))
        task = rng.choice(v_p, [num_points * length * v_s])

    task_onehot = onehot(task, num_points, length, v_s, v_p)

    rows = np.arange(num_points)
    searches = np.zeros([num_points, length, length])
    retrievals = np.zeros([num_points, length, v_s, v_p])

    for i in range(length):
        for key in range(v_s):  # named `key` to avoid shadowing the function `search`
            winners = search_by_indices(data, i, key)
            searches[rows, i, winners] = 1.
            # All v_p value features of each winning step in one gather.
            retrievals[:, i, key] = data[rows, winners, v_s:]

    # Value selected by each slot's task id: (num_points, length, v_s).
    per_slot = np.sum(task_onehot * retrievals, axis=-1)
    if cff:
        # Only build the weights when they are used, so cff=False works for any v_s.
        if v_s > len(coeffs):
            raise ValueError(
                f"cff=True needs v_s <= {len(coeffs)} coefficients, got v_s={v_s}"
            )
        per_slot = np.reshape(coeffs[:v_s], [1, 1, v_s]) * per_slot
    labels = np.sum(per_slot, axis=-1)

    inp = np.concatenate(
        (data, np.reshape(task_onehot, (num_points, length, v_s * v_p))), axis=-1
    )

    return inp, labels, searches